# +----------------------------------------------------------------------
# | Pink [ A modern python web framework ]
# +----------------------------------------------------------------------
# | Copyright (c) 2023 http://unnnnn.com All rights reserved.
# +----------------------------------------------------------------------
# | Author: chenjianhua <unnnnn@foxmail.com>
# +----------------------------------------------------------------------

import os
from os import listdir
from os.path import isfile, join
from pydoc import locate
from timeit import default_timer as timer

from inflection import camelize

from ..config import load_config
from ..models.MigrationModel import MigrationModel
from ..schema import Schema


class Migration:
    """Discover, apply, roll back and track migration modules.

    Migration files live in ``migration_directory`` (resolved against the
    current working directory). Each applied file name, stored without its
    ``.py`` suffix, is recorded in the ``migrations`` table together with the
    batch number it ran in.
    """

    def __init__(
        self,
        connection="default",
        dry=False,
        command_class=None,
        migration_directory="databases/migrations",
        config_path=None,
        schema=None,
    ):
        """Set up the schema builder and the migrations-table model.

        Args:
            connection: Name of the database connection to use.
            dry: Forwarded to ``Schema``; presumably builds SQL without
                executing it (TODO confirm against ``Schema``).
            command_class: Optional console command used for output through
                its ``line``/``table``/``io`` members. When None, progress
                messages are suppressed and errors are printed.
            migration_directory: Directory holding the migration modules,
                relative to ``os.getcwd()``. Also used as a dotted import path
                by ``locate``.
            config_path: Passed to ``load_config`` to find the DB config.
            schema: Optional database schema name, applied to both the
                ``Schema`` builder and the migrations model.
        """
        self.connection = connection
        self.migration_directory = migration_directory
        self.last_migrations_ran = []
        self.command_class = command_class

        self.schema_name = schema

        DB = load_config(config_path).DB

        DATABASES = DB.get_connection_details()

        self.schema = Schema(
            connection=connection,
            connection_details=DATABASES,
            dry=dry,
            schema=self.schema_name,
        )

        self.migration_model = MigrationModel.on(self.connection)
        if self.schema_name:
            self.migration_model.set_schema(self.schema_name)

    # ------------------------------------------------------------------
    # Output helpers
    # ------------------------------------------------------------------

    def _line(self, message):
        """Write *message* to the console command, if one was supplied."""
        if self.command_class:
            self.command_class.line(message)

    def _error(self, message):
        """Report an error; falls back to ``print`` when there is no command.

        The original code called ``self.command_class.line`` unconditionally
        here, which raised AttributeError when no command was supplied.
        """
        if self.command_class:
            self.command_class.line(f"<error>{message}</error>")
        else:
            print(message)

    @staticmethod
    def _render_sql(migration_instance):
        """Return the SQL a (dry-run) migration produced, or None if none.

        Prefers the schema's ``_blueprint`` and joins list output with ``,``.
        Otherwise falls back to a truthy ``_sql`` attribute on the schema.
        """
        schema = migration_instance.schema
        blueprint = getattr(schema, "_blueprint", None)
        if blueprint:
            sql = blueprint.to_sql()
            if isinstance(sql, list):
                sql = ",".join(sql)
            return sql
        return getattr(schema, "_sql", None) or None

    def _show_sql(self, migration_instance):
        """Display a dry-run migration's SQL as a table (or print it)."""
        sql = self._render_sql(migration_instance)
        if self.command_class:
            table = self.command_class.table()
            table.set_header_row(["SQL"])
            if sql is not None:
                table.set_rows([[sql]])
            table.render(self.command_class.io)
        elif sql is not None:
            print(sql)

    def _resolve_migration(self, migration):
        """Return the class for *migration*, or None after reporting it missing.

        ``pydoc.locate`` returns None (rather than raising) when the dotted
        path cannot be resolved, so the None check is what detects a missing
        migration. TypeError is still caught for backward compatibility.
        """
        try:
            migration_class = self.locate(migration)
        except TypeError:
            migration_class = None
        if migration_class is None:
            self._error(f"Not Found: {migration}")
        return migration_class

    # ------------------------------------------------------------------
    # Migrations table / discovery
    # ------------------------------------------------------------------

    def create_table_if_not_exists(self):
        """Create the ``migrations`` tracking table if it is missing.

        Returns:
            True if the table was created, False if it already existed.
        """
        if not self.schema.has_table("migrations"):
            with self.schema.create("migrations") as table:
                table.increments("migration_id")
                table.string("migration")
                table.integer("batch")

            return True

        return False

    def _migration_files(self):
        """Return sorted migration module names (without ``.py``) on disk.

        Only regular ``.py`` files are considered. ``__init__.py`` and hidden
        (dot-prefixed) files are skipped.
        """
        directory_path = os.path.join(os.getcwd(), self.migration_directory)
        return sorted(
            file_name[: -len(".py")]
            for file_name in listdir(directory_path)
            if file_name.endswith(".py")
            and file_name != "__init__.py"
            and not file_name.startswith(".")
            and isfile(join(directory_path, file_name))
        )

    def get_unran_migrations(self):
        """Return on-disk migration names not yet recorded in the table, sorted."""
        on_disk = self._migration_files()
        # Pluck once rather than on every loop iteration.
        recorded = self.migration_model.all().pluck("migration")
        return [name for name in on_disk if name not in recorded]

    def get_rollback_migrations(self):
        """Return the migrations of the highest batch, newest first."""
        return (
            self.migration_model.where("batch", self.migration_model.all().max("batch"))
            .order_by("migration_id", "desc")
            .get()
            .pluck("migration")
        )

    def get_all_migrations(self, reverse=False):
        """Return every recorded migration name.

        Args:
            reverse: When True, order by ``migration_id`` descending.
        """
        if reverse:
            return (
                self.migration_model.order_by("migration_id", "desc")
                .get()
                .pluck("migration")
            )

        return self.migration_model.all().pluck("migration")

    def get_last_batch_number(self):
        """Return the highest ``batch`` value in the migrations table.

        The value for an empty table depends on the collection's ``max``.
        ``migrate`` adds 1 to this result, so it presumably yields 0
        (TODO confirm).
        """
        return self.migration_model.select("batch").get().max("batch")

    def delete_migration(self, file_path):
        """Delete the record whose ``migration`` column equals *file_path*."""
        return self.migration_model.where("migration", file_path).delete()

    def locate(self, file_name):
        """Import and return the migration class defined in *file_name*.

        The class name is the camelized form of the file name with its first
        four ``_``-separated segments dropped (presumably a date/time prefix
        such as ``2023_01_01_000000``). ``migration_directory`` is converted
        to a dotted module path. Returns None when the path cannot be
        resolved.
        """
        migration_name = camelize("_".join(file_name.split("_")[4:]).replace(".py", ""))
        file_name = file_name.replace(".py", "")
        migration_directory = self.migration_directory.replace("/", ".").replace(
            "\\", "."
        )
        return locate(f"{migration_directory}.{file_name}.{migration_name}")

    def get_ran_migrations(self):
        """Return on-disk migration names already recorded in the table, sorted."""
        on_disk = self._migration_files()
        recorded = self.migration_model.all().pluck("migration")
        return [name for name in on_disk if name in recorded]

    # ------------------------------------------------------------------
    # Commands
    # ------------------------------------------------------------------

    def migrate(self, migration="all", output=False):
        """Run ``up()`` for pending migrations and record them in one batch.

        Args:
            migration: ``"all"`` for every unran migration, or a single file
                name to run.
            output: When True, run in dry mode and only display the SQL.
                Nothing is recorded in the migrations table.
        """
        migrations = (
            self.get_unran_migrations() if migration == "all" else [migration]
        )

        batch = self.get_last_batch_number() + 1

        for name in migrations:
            migration_class = self._resolve_migration(name)
            if migration_class is None:
                continue

            self.last_migrations_ran.append(name)
            self._line(f"<comment>Migrating:</comment> <question>{name}</question>")

            instance = migration_class(
                connection=self.connection, schema=self.schema_name
            )

            if output:
                instance.schema.dry()
            start = timer()
            instance.up()
            duration = "{:.2f}".format(timer() - start)

            if output:
                # Dry run: show the SQL but never record the migration as run.
                self._show_sql(instance)
                continue

            self._line(
                f"<info>Migrated:</info> <question>{name}</question> ({duration}s)"
            )

            self.migration_model.create(
                {"batch": batch, "migration": name.replace(".py", "")}
            )

    def rollback(self, migration="all", output=False):
        """Run ``down()`` for the last batch (or one migration) and unrecord it.

        Args:
            migration: ``"all"`` for the most recent batch, or a single file
                name to roll back.
            output: When True, run in dry mode and only display the SQL.
                The migration record is kept.
        """
        migrations = (
            self.get_rollback_migrations() if migration == "all" else [migration]
        )

        for name in migrations:
            if name.endswith(".py"):
                name = name.replace(".py", "")

            self._line(
                f"<comment>Rolling back:</comment> <question>{name}</question>"
            )

            migration_class = self._resolve_migration(name)
            if migration_class is None:
                continue

            instance = migration_class(
                connection=self.connection, schema=self.schema_name
            )

            if output:
                instance.schema.dry()

            start = timer()
            instance.down()
            duration = "{:.2f}".format(timer() - start)

            if output:
                # Dry run: show the SQL but keep the migration record.
                self._show_sql(instance)
                continue

            self.delete_migration(name)

            self._line(
                f"<info>Rolled back:</info> <question>{name}</question> ({duration}s)"
            )

    def delete_migrations(self, migrations=None):
        """Delete the records for every name in *migrations* (none if None)."""
        return self.migration_model.where_in("migration", migrations or []).delete()

    def delete_last_batch(self):
        """Delete every record belonging to the highest batch number."""
        return self.migration_model.where(
            "batch", self.get_last_batch_number()
        ).delete()

    def reset(self, migration="all"):
        """Run ``down()`` for every recorded migration (newest first) and unrecord it.

        Args:
            migration: ``"all"`` for every recorded migration, or a single
                file name to reset.
        """
        migrations = (
            self.get_all_migrations(reverse=True)
            if migration == "all"
            else [migration]
        )

        if not len(migrations):
            if self.command_class:
                self.command_class.line("<info>Nothing to reset</info>")
            else:
                print("Nothing to reset")

        for name in migrations:
            # Match rollback(): records are stored without the .py suffix.
            if name.endswith(".py"):
                name = name.replace(".py", "")

            self._line(
                f"<comment>Rolling back:</comment> <question>{name}</question>"
            )

            migration_class = self._resolve_migration(name)
            if migration_class is None:
                continue

            # Errors raised by down() propagate instead of being
            # misreported as "Not Found".
            migration_class(
                connection=self.connection, schema=self.schema_name
            ).down()

            self.delete_migration(name)

            self._line(f"<info>Rolled back:</info> <question>{name}</question>")

        if self.command_class:
            self.command_class.line("")

    def refresh(self, migration="all"):
        """Reset and then re-run *migration* (or every migration for ``"all"``)."""
        self.reset(migration)
        self.migrate(migration)
